import datetime
import functools
import glob
import json
import os
import pathlib
import sys
import time
import traceback
from threading import Thread, Lock

from quant import utils


class Iota:
    """Thread-safe counter handing out 1, 2, 3, ... on successive ``next()`` calls."""

    def __init__(self):
        self.count = 0  # last value handed out; 0 means none yet
        self.lock = Lock()  # serializes increments across threads

    def next(self):
        """Increment the counter under the lock and return the new value."""
        self.lock.acquire()
        try:
            self.count = self.count + 1
            value = self.count
        finally:
            self.lock.release()
        return value


class Saving:
    """Small persistent key/value store backed by a single JSON file.

    The file is ``<path>/<name>.json``, or ``savings/<name>.json`` next to this
    module when *path* is None. Every mutating call rewrites the whole file.
    """

    def __init__(self, name, path=None):
        self.name = name

        if path is None:
            self._file_name = pathlib.Path(__file__).parent.joinpath('savings', '{}.json'.format(name))
        else:
            self._file_name = pathlib.Path(path).joinpath('{}.json'.format(name))

        # parents=True: os.mkdir used to fail when intermediate directories were
        # missing. exist_ok=True: no race between the is_dir() check and mkdir.
        self._file_name.parent.mkdir(parents=True, exist_ok=True)

        self._dict = {}
        self._load()

    def __repr__(self):
        return 'Saving({})'.format(self._dict)

    def _load(self):
        """Populate ``self._dict`` from disk, creating an empty store if absent.

        An unreadable or unparseable file is logged via ``catch_exception`` and
        treated as an empty store.
        """
        if not self._file_name.is_file():
            self._save()

        # ``with`` closes the handle; the original leaked it.
        with open(self._file_name, encoding='utf-8') as fp:
            try:
                self._dict = json.load(fp)
            except (ValueError, OSError):  # json.JSONDecodeError is a ValueError
                catch_exception()
                self._dict = {}

    def _save(self):
        """Write ``self._dict`` to disk atomically.

        The data goes to a sibling ``.tmp`` file first and is then moved over the
        real file with ``os.replace``, so a crash mid-write cannot leave a
        truncated JSON file behind.
        """
        tmp_name = self._file_name.with_name(self._file_name.name + '.tmp')
        with open(tmp_name, 'w', encoding='utf-8') as fp:
            json.dump(self._dict, fp)
        os.replace(tmp_name, self._file_name)

    def get(self, name):
        """Return the value stored under *name*; raises KeyError if missing."""
        return self._dict[name]

    def get_default(self, name, default=None):
        """Return the value stored under *name*, or *default* if missing."""
        return self._dict.get(name, default)

    def set(self, name, value):
        """Store *value* under *name* and persist."""
        self._dict[name] = value
        self._save()

    def update(self, dic):
        """Merge *dic* into the store and persist."""
        self._dict.update(dic)
        self._save()

    def override(self, dic):
        """Replace the whole store with the contents of *dic* and persist."""
        self._dict.clear()
        self._dict.update(dic)
        self._save()


class AccurateTime:  # not necessary. not fully checked!
    """Wall-clock-like timestamps computed from ``time.perf_counter``.

    The offset between ``time.time()`` and ``time.perf_counter()`` is captured
    once at construction and added to every later ``perf_counter()`` reading.
    """

    def __init__(self):
        wall = time.time()
        counter = time.perf_counter()
        self.add = wall - counter  # offset mapping perf_counter() onto wall time

    def time(self):
        """Return ``perf_counter()`` shifted by the offset captured at construction."""
        return self.add + time.perf_counter()


class AccurateTime2:  # not necessary. not fully checked!
    """Wall-clock-like timestamps from ``perf_counter`` with drift correction.

    The offset to ``time.time()`` is re-anchored whenever the derived value
    differs from the wall clock by more than one second. Returned values never
    decrease from one call to the next.
    """

    def __init__(self):
        self.add = time.time() - time.perf_counter()  # perf_counter -> wall offset
        self.last = 0  # most recent value returned by time()

    def _resync(self):
        """Re-anchor the offset against the wall clock and return ``time.time()``."""
        counter = time.perf_counter()
        wall = time.time()
        self.add = wall - counter
        return wall

    def time(self):
        """Return the current timestamp, clamped to be >= the previous one."""
        stamp = self.add + time.perf_counter()
        if abs(time.time() - stamp) > 1:
            # Drifted more than a second from the wall clock: re-anchor.
            stamp = self._resync()
        # Clamp so successive calls are non-decreasing.
        if stamp < self.last:
            stamp = self.last
        self.last = stamp
        return stamp


def perf_test(func, args=(), n_times=1000):
    """Benchmark ``func(*args)`` over *n_times* calls and print the timings.

    Args:
        func: callable to benchmark.
        args: positional arguments passed to every call.
        n_times: number of calls; must be at least 1.

    Returns:
        Average time per call in microseconds, rounded to 3 decimals (the same
        figure that is printed). The original returned None.

    Raises:
        ValueError: if *n_times* < 1. The original printed and then hit
            ZeroDivisionError.
    """
    if n_times < 1:
        raise ValueError('n_times must be >= 1, got {!r}'.format(n_times))

    t1 = time.perf_counter()
    for _ in range(n_times):
        func(*args)
    t2 = time.perf_counter()

    elapsed = t2 - t1
    print(f'{n_times} times consume {round(elapsed, 6)}s')
    every = round(elapsed / n_times * 1000000, 3)
    print(f'result: {every}us per call')
    return every


def perf_intv(name=None):
    """Simple stopwatch on the module-global ``_perf`` stamp.

    Call with no *name* to (re)start the stopwatch. Call with a *name* to print
    the time elapsed since the last start, in microseconds. A named call does
    not reset the start stamp.
    """
    global _perf
    stamp = time.perf_counter()
    if name is not None:
        elapsed = stamp - _perf
        header = '{} consume'.format(name.capitalize())
        print('{}: {:.1f}us'.format(header, elapsed * 1_000_000))
        return
    _perf = stamp


def perf_intv_2(key, name=None):
    """Keyed stopwatch backed by the module-global ``_perf_map``.

    With no *name*, record the current ``perf_counter()`` under *key*. With a
    *name*, print the microseconds elapsed since *key* was recorded. Does
    nothing if *key* was never recorded.
    """
    stamp = time.perf_counter()
    if name is None:
        _perf_map[key] = stamp
        return
    if key not in _perf_map:
        return
    elapsed = stamp - _perf_map[key]
    header = '{} consume'.format(name.capitalize())
    print('{}: {:.1f}us'.format(header, elapsed * 1_000_000))


def catch_exception():
    """Log the exception currently being handled through ``utils.logging.error``.

    The log message has one ``<file> line <n>: <source>`` entry per traceback
    frame. ``<file>`` is the bare file name with the directory stripped for both
    '/' and '\\' separators. The exception value is appended on the final line.
    If building or logging the message itself fails, a banner and that failure's
    traceback are printed instead.
    """
    try:
        _, exc_value, exc_tb = sys.exc_info()
        lines = []
        for frame in traceback.extract_tb(exc_tb):
            # Normalise separators so the last path component is the file name.
            short_name = frame.filename.replace('\\', '/').rsplit('/', 1)[-1]
            lines.append('{} line {}: {}'.format(short_name, frame.lineno, frame.line))
        message = '\n'.join(lines) + '\n{}'.format(exc_value)
        utils.logging.error(message)
    except BaseException:
        print('------------Catching Failed Exception------------')
        traceback.print_exc()


def abstract_method(func):
    """Decorator turning *func* into a stub that always raises NotImplementedError.

    ``functools.wraps`` keeps the original ``__name__``, ``__doc__`` and
    ``__wrapped__``. Without it, every decorated method showed up as
    ``new_func`` in tracebacks and ``help()``.
    """
    @functools.wraps(func)
    def new_func(*args, **kwargs):
        err = 'Abstract method {} not callable!'.format(str(func))
        raise NotImplementedError(err)
    return new_func


def set_test_mode():
    """Put the project logger into test mode and announce it.

    Calls ``utils.logging.set_level(4)``, then logs a warning. What level 4
    means is defined in ``quant.utils`` (not visible here).
    """
    logger = utils.logging
    logger.set_level(4)
    logger.warn('test mode is set!')


def data_routing_key(event, exchange, symbol):
    """Build the dotted routing key ``'<event>.<exchange>.<symbol>'``."""
    return '{}.{}.{}'.format(event, exchange, symbol)


def except_caught_fn(func):
    """Wrap *func* so that ordinary exceptions are logged instead of raised.

    On success the wrapper returns *func*'s result. If *func* raises an
    ``Exception``, it is logged via ``catch_exception()`` and the wrapper
    returns None.

    Only ``Exception`` is caught. The original bare ``except:`` also swallowed
    ``KeyboardInterrupt`` and ``SystemExit``, which made Ctrl-C and
    ``sys.exit()`` inside wrapped calls silently do nothing.
    ``functools.wraps`` keeps *func*'s name and docstring on the wrapper.
    """
    @functools.wraps(func)
    def new_fn(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            catch_exception()
            return None
    return new_fn


def read_files(path):
    """List the names of the ``*.py`` files directly inside directory *path*.

    Args:
        path: directory to scan (str or path-like).

    Returns:
        Sorted list of file names with the final ``.py`` extension removed,
        e.g. ``['a', 'my.pyramid']``. Empty if there are no matches.
    """
    # glob.escape: a directory name containing '[', '*' or '?' must be matched
    # literally, not interpreted as part of the pattern.
    pattern = os.path.join(glob.escape(str(path)), '*.py')
    # splitext strips only the trailing extension. The old
    # str.replace('.py', '') also mangled e.g. 'my.pyramid.py' -> 'myramid'.
    names = (os.path.splitext(os.path.basename(file))[0] for file in glob.glob(pattern))
    # glob order is filesystem-dependent; sort for deterministic output.
    return sorted(names)


def thread_call(func, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` on a new daemon thread.

    *func* is wrapped with ``except_caught_fn``, so exceptions it raises are
    logged through ``catch_exception`` rather than propagating out of the thread.

    Returns:
        The started ``threading.Thread``, so callers can ``join()`` it if they
        need to wait. Previously the function returned None; ignoring the
        return value keeps the old behaviour.
    """
    thread = Thread(target=except_caught_fn(func), args=args, kwargs=kwargs, daemon=True)
    thread.start()
    return thread


def accurate_time():  # not necessary. not fully checked!
    """Return ``time()`` from the shared module-level ``AccurateTime`` instance."""
    clock = _accurate_time_instance
    return clock.time()


def partial_function(fn, *args):
    """Return a zero-argument callable that invokes ``fn(*args)``.

    The positional arguments are captured now. The returned function takes no
    arguments of its own.
    """
    captured = args

    def new_fn():
        return fn(*captured)

    return new_fn


def get_logs_path():
    """Return ``Path(__file__).parent.parent.parent / 'logs'`` (not created here)."""
    base = pathlib.Path(__file__)
    for _ in range(3):
        base = base.parent
    return base.joinpath('logs')


def max_(a, b):
    """Return the larger of *a* and *b*; on a tie, return *b*."""
    return a if a > b else b


def min_(a, b):
    """Return the smaller of *a* and *b*; on a tie, return *a*."""
    return b if a > b else a


def get_beijing_dt():
    """Return the current time as an aware datetime in the module's ``_tz`` zone (fixed UTC+8)."""
    return datetime.datetime.now(tz=_tz)


# Module-level state shared by the helper functions above.
_tz = datetime.timezone(datetime.timedelta(hours=8))  # fixed UTC+8 offset, used by get_beijing_dt()
_accurate_time_instance = AccurateTime()  # shared clock behind accurate_time()
_perf_map = {}  # key -> perf_counter() start stamp, used by perf_intv_2()
_perf = 0  # last perf_counter() stamp recorded by perf_intv()

if __name__ == '__main__':
    # Manual smoke checks. Only set_test_mode() and the final print run by
    # default; the nested helpers are kept for ad-hoc use when editing this block.
    set_test_mode()

    def try_saving():
        """Exercise Saving against savings/s1.json next to this module."""
        a = Saving('s1')
        # get_default: get('x') raised KeyError on a freshly created store,
        # because nothing sets 'x' before this read.
        print(a.get_default('x'))

        a.set('z', 3)
        a.set('xx', [1, 2, 4])
        a.update({
            'x': 11,
            'z': 33,
            'xx': [4, 2, 8],
        })

    def fn1(a, b, c):
        print('fn1({}, {}, {})'.format(a, b, c))
        return a + b + c

    def fn2():
        print('fn2()')

    def aaa(x, y):
        print(x, y)

    print(get_beijing_dt())
















